import copy
from typing import Dict

from jsonargparse.typing import PositiveInt

from data_engine.utils.availability_utils import AvailabilityChecking
from data_engine.utils.constant import Fields
from data_engine.utils.mm_utils import SpecialTokens, remove_special_tokens
from data_engine.utils.model_utils import get_model, prepare_model

from ..base_op import OPERATORS, Mapper

NAME = 'video_captioning_from_summarizer_mapper'
# Third-party packages needed by this op and by the sub-ops it may build.
# Trailing comments record which sub-op a package is needed by, where known.
CHECK_PKGS = [
    'torch',
    'transformers',
    'simhash-pybind',  # by video caption
    'transformers_stream_generator',
    'einops',
    'accelerate',
    'tiktoken',  # by audio caption
    'torchaudio',  # by audio tag
    'git+https://github.com/xinyu1205/recognize-anything.git',  # by frame tag
]

# Import the optional dependencies inside AvailabilityChecking, which is given
# CHECK_PKGS and NAME. Presumably it reports missing packages for this op;
# confirm against data_engine.utils.availability_utils.
with AvailabilityChecking(CHECK_PKGS, NAME):
    # Imports marked F401 are not used directly in this module; they only
    # make sure the packages the sub-ops depend on can be imported.
    import accelerate  # noqa: F401
    import einops  # noqa: F401
    # frame tag (recognize-anything)
    import ram  # noqa: F401
    # video caption
    import simhash  # noqa: F401
    # audio caption
    import tiktoken  # noqa: F401
    import torch
    # audio tag
    import torchaudio  # noqa: F401
    import transformers  # noqa: F401
    import transformers_stream_generator  # noqa: F401

    # avoid hanging when calling clip in multiprocessing
    torch.set_num_threads(1)


@OPERATORS.register_module(NAME)
class VideoCaptioningFromSummarizerMapper(Mapper):
    """
    Mapper to generate video captions by summarizing several kinds of generated
    texts (captions from video/audio/frames, tags from audio/frames, ...)
    """

    _accelerator = 'cuda'
    _batched_op = True

    def __init__(self,
                 hf_summarizer: str = None,
                 trust_remote_code=False,
                 consider_video_caption_from_video: bool = True,
                 consider_video_caption_from_audio: bool = True,
                 consider_video_caption_from_frames: bool = True,
                 consider_video_tags_from_audio: bool = True,
                 consider_video_tags_from_frames: bool = True,
                 vid_cap_from_vid_args: Dict = None,
                 vid_cap_from_frm_args: Dict = None,
                 vid_tag_from_aud_args: Dict = None,
                 vid_tag_from_frm_args: Dict = None,
                 keep_tag_num: PositiveInt = 5,
                 keep_original_sample: bool = True,
                 *args,
                 **kwargs):
        """
        Initialization method.

        :param hf_summarizer: the summarizer model used to summarize texts
            generated by other methods. If not given,
            'mrm8488/flan-t5-large-finetuned-openai-summarize_from_feedback'
            is used.
        :param trust_remote_code: passed through to ``prepare_model`` when
            loading the summarizer. Default: False.
        :param consider_video_caption_from_video: whether to consider the video
            caption generated from video directly in the summarization process.
            Default: True.
        :param consider_video_caption_from_audio: whether to consider the video
            caption generated from audio streams in the video in the
            summarization process. Default: True.
        :param consider_video_caption_from_frames: whether to consider the
            video caption generated from sampled frames from the video in the
            summarization process. Default: True.
        :param consider_video_tags_from_audio: whether to consider the video
            tags generated from audio streams in the video in the summarization
            process. Default: True.
        :param consider_video_tags_from_frames: whether to consider the video
            tags generated from sampled frames from the video in the
            summarization process. Default: True.
        :param vid_cap_from_vid_args: the arg dict for video captioning from
            video directly with keys are the arg names and values are the arg
            values. Default: None.
        :param vid_cap_from_frm_args: the arg dict for video captioning from
            sampled frames from the video with keys are the arg names and
            values are the arg values. Default: None.
        :param vid_tag_from_aud_args: the arg dict for video tagging from audio
            streams in the video with keys are the arg names and values are the
            arg values. Default: None.
        :param vid_tag_from_frm_args: the arg dict for video tagging from
            sampled frames from the video with keys are the arg names and
            values are the arg values. Default: None.
        :param keep_tag_num: max number N of tags from sampled frames to keep.
            Too many tags might bring negative influence to summarized text, so
            we consider to only keep the N most frequent tags. Default: 5.
        :param keep_original_sample: whether to keep the original sample. If
            it's set to False, there will be only summarized captions in the
            final datasets and the original captions will be removed. It's True
            in default.
        :param args: extra args
        :param kwargs: extra args
        """
        super().__init__(*args, **kwargs)

        self.keep_original_sample = keep_original_sample
        self.extra_args = kwargs

        # prepare summarizer
        self._hf_summarizer = hf_summarizer if hf_summarizer else 'mrm8488/flan-t5-large-finetuned-openai-summarize_from_feedback'  # noqa: E501
        self.model_key = prepare_model(
            model_type='huggingface',
            pretrained_model_name_or_path=self._hf_summarizer,
            trust_remote_code=trust_remote_code)

        # prepare input texts ops
        if vid_cap_from_vid_args is None:
            vid_cap_from_vid_args = {}
        if vid_cap_from_frm_args is None:
            vid_cap_from_frm_args = {}
        if vid_tag_from_aud_args is None:
            vid_tag_from_aud_args = {}
        if vid_tag_from_frm_args is None:
            vid_tag_from_frm_args = {}
        # Args forced on every sub-op: one caption per video, and the sub-ops
        # must not keep the original sample themselves.
        self.FIXED_ARGS = {
            'caption_num': 1,
            'keep_candidate_mode': 'random_any',
            'keep_original_sample': False,
        }
        self.cap_op_list = []
        self.tag_op_list = []
        # Sub-ops are imported lazily so that only the enabled ones are loaded.
        if consider_video_caption_from_video:
            from .video_captioning_from_video_mapper import \
                VideoCaptioningFromVideoMapper
            self.cap_op_list.append(
                VideoCaptioningFromVideoMapper(**self._prepare_op_args(
                    VideoCaptioningFromVideoMapper, vid_cap_from_vid_args)))
        if consider_video_caption_from_audio:
            from .video_captioning_from_audio_mapper import \
                VideoCaptioningFromAudioMapper
            self.cap_op_list.append(
                VideoCaptioningFromAudioMapper(**self._prepare_op_args(
                    VideoCaptioningFromAudioMapper, {})))
        if consider_video_caption_from_frames:
            from .video_captioning_from_frames_mapper import \
                VideoCaptioningFromFramesMapper
            self.cap_op_list.append(
                VideoCaptioningFromFramesMapper(**self._prepare_op_args(
                    VideoCaptioningFromFramesMapper, vid_cap_from_frm_args)))
        if consider_video_tags_from_audio:
            from .video_tagging_from_audio_mapper import \
                VideoTaggingFromAudioMapper
            self.tag_op_list.append(
                VideoTaggingFromAudioMapper(**self._prepare_op_args(
                    VideoTaggingFromAudioMapper, vid_tag_from_aud_args)))
        if consider_video_tags_from_frames:
            from .video_tagging_from_frames_mapper import \
                VideoTaggingFromFramesMapper
            self.tag_op_list.append(
                VideoTaggingFromFramesMapper(**self._prepare_op_args(
                    VideoTaggingFromFramesMapper, vid_tag_from_frm_args)))

        self.keep_tag_num = keep_tag_num

    def _prepare_op_args(self, op_class, args_dict):
        """
        Build the constructor kwargs for one internal sub-op.

        :param op_class: the sub-op class that will be instantiated.
        :param args_dict: user-provided args for this sub-op. It is NOT
            modified.
        :return: a new dict with ``args_dict`` merged with ``self.FIXED_ARGS``
            (fixed args take precedence), restricted to parameter names
            declared by ``op_class.__init__``, plus this op's ``accelerator``.
        """
        init_code = op_class.__init__.__code__
        # co_varnames also lists plain local variables of __init__; only the
        # leading positional and keyword-only names are real parameters.
        param_count = init_code.co_argcount + init_code.co_kwonlyargcount
        accepted_args = set(init_code.co_varnames[:param_count])

        merged_args = {**args_dict, **self.FIXED_ARGS}
        op_args = {
            key: value
            for key, value in merged_args.items() if key in accepted_args
        }
        op_args['accelerator'] = self.accelerator
        return op_args

    def _process_single_sample(self, sample, rank=None):
        """
        Generate one summarized-caption sample from a single sample.

        The text is split into chunks by ``SpecialTokens.eoc``. Chunks that
        are empty or only whitespace are dropped. Chunks without video tokens
        are kept as-is (followed by ``eoc``). Each chunk with video tokens is
        replaced by its video tokens, a space and a summary of the texts that
        the enabled sub-ops generate for that chunk's videos.

        :param sample: a single sample dict.
        :param rank: device rank forwarded to model loading and sub-ops.
        :return: ``[captioned_sample]``, or ``[]`` when the sample has no
            video or no sub-op is enabled.
        """
        # there is no video in this sample
        if self.video_key not in sample or not sample[self.video_key]:
            return []

        # there is no activated ops
        if len(self.cap_op_list) == 0 and len(self.tag_op_list) == 0:
            return []

        # get paths of all video(s)
        loaded_video_keys = sample[self.video_key]

        # get models
        model, tokenizer = get_model(self.model_key, rank, self.use_cuda())

        captioned_sample = copy.deepcopy(sample)
        # generate for each video chunk by chunk
        captioned_texts = ''
        offset = 0
        for chunk in sample[self.text_key].split(SpecialTokens.eoc):
            # skip empty chunks
            if not chunk.strip():
                continue

            vid_count = chunk.count(SpecialTokens.video)

            if vid_count == 0:
                # add special tokens
                captioned_texts += f'{chunk}{SpecialTokens.eoc}'
                continue

            # make a temporary sample holding only this chunk's videos
            temp_sample = {
                self.text_key: chunk,
                self.video_key: loaded_video_keys[offset:offset + vid_count],
            }

            captioned_text_list = []
            # tag ops
            for op in self.tag_op_list:
                temp_sample = op.process(temp_sample, rank=rank)
            if Fields.video_audio_tags in temp_sample:
                captioned_text_list.extend(
                    temp_sample[Fields.video_audio_tags])
            if Fields.video_frame_tags in temp_sample:
                for tag_list in temp_sample[Fields.video_frame_tags]:
                    # Keep the first N tags of each video (per the
                    # keep_tag_num docs, presumably the most frequent ones).
                    # Slicing also tolerates videos with fewer than N tags.
                    captioned_text_list.extend(tag_list[:self.keep_tag_num])
            # cap ops
            for op in self.cap_op_list:
                generated_samples = op._process_single_sample(temp_sample,
                                                              rank=rank)
                # a sub-op may yield no sample; skip it instead of failing
                if generated_samples:
                    captioned_text_list.append(
                        remove_special_tokens(
                            generated_samples[0][op.text_key]))

            # summarization
            all_texts = ', '.join(captioned_text_list)
            input_ids = tokenizer(all_texts, return_tensors='pt').input_ids.to(
                model.device)
            outputs = model.generate(input_ids, max_new_tokens=128)
            summarized_text = tokenizer.decode(outputs[0],
                                               skip_special_tokens=True)

            offset += vid_count
            captioned_text = f'{SpecialTokens.video * vid_count} ' \
                             f'{summarized_text}'

            # add special tokens
            captioned_texts += f'{captioned_text}{SpecialTokens.eoc}'

        captioned_sample[self.text_key] = captioned_texts
        return [captioned_sample]

    def process(self, samples, rank=None):
        """
        Process a batch of samples given as a "dict of lists".

        For each sample, the original is kept first (if
        ``keep_original_sample``), followed by any generated sample.

        :param samples: batched samples, mapping each key to a list of values.
        :param rank: device rank forwarded to the per-sample processing.
        :return: the resulting batch as a "dict of lists". If no sample
            remains, every input key maps to an empty list.
        """
        input_keys = list(samples.keys())
        # reconstruct samples from "dict of lists" to "list of dicts"
        reconstructed_samples = [{key: samples[key][i]
                                  for key in input_keys}
                                 for i in range(len(samples[self.text_key]))]
        samples_after_split = []
        # do split for each sample within the batch
        for ori_sample in reconstructed_samples:
            if self.keep_original_sample:
                samples_after_split.append(ori_sample)
            samples_after_split.extend(
                self._process_single_sample(ori_sample, rank=rank))

        # nothing left (e.g. originals dropped and no sample had a video):
        # return empty columns instead of indexing an empty list
        if not samples_after_split:
            return {key: [] for key in input_keys}

        # reconstruct samples from "list of dicts" to "dict of lists"
        keys = samples_after_split[0].keys()
        return {key: [s[key] for s in samples_after_split] for key in keys}
